{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "746c6089-658a-46b0-becd-44ed59f24ebe",
   "metadata": {},
   "source": [
    "# Animal Mixer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fa554eb-db7f-486c-971b-98fae51107bd",
   "metadata": {},
   "source": [
    "Given two animal species, let's make a cross between them and visualize the resulting new animal."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e8c89b2-b4e8-48bb-9e2b-4455a5dd5a6e",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c063afce-a8e9-48cf-a08e-d70db2bb62e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import gradio as gr\n",
    "import base64\n",
    "from io import BytesIO\n",
    "from PIL import Image\n",
    "from IPython.display import Audio, display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab174215-1029-40df-9d75-a30f1c399fc9",
   "metadata": {},
   "source": [
    "## Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b91a133e-becc-45ee-ad4c-6d3469c78826",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv(override=True)\n",
    "\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "if openai_api_key:\n",
    "    print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
    "else:\n",
    "    print(\"OpenAI API Key not set\")\n",
    "    \n",
    "MODEL = \"gpt-4o-mini\"\n",
    "openai = OpenAI()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e696d093-3b8b-4275-939c-53c7b623469b",
   "metadata": {},
   "source": [
    "## System Messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f293608-376e-4f91-afce-e9d93787db03",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = \"You are a famous zoologist-surgeon who makes crosses between animals, so new hybrid animals with mixed features of both original animals. \"\n",
    "system_message += \"Given two animal species, you create a new species which is a hybrid between the two. Make sure it only has one head. \"\n",
    "system_message += \"Describe the new species following the pattern: species X is a hybrid between species A and species B. \"\n",
    "system_message += \"Species A and B are the two given species. Describe the new species briefly, in up to 3 sentences. \"\n",
    "system_message += \"Always be accurate. If you don't know the answer, say so.\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "418e9bdd-6a94-4054-8e5b-0431866572ab",
   "metadata": {},
   "source": [
    "## Tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8175d1d3-1df8-4e76-8437-cdf5bb32c6db",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_animal_name(animal1, animal2):\n",
    "    print(f\"Tool get_animal_name called for the cross between {animal1} and {animal2}\")\n",
    "    first = len(animal1) // 2\n",
    "    second = len(animal2) // 2\n",
    "    name = animal1[:first] + animal2[second:]\n",
    "    return name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f56b440-0e5a-42dc-94ed-c8e0ea37fb03",
   "metadata": {},
   "outputs": [],
   "source": [
    "get_animal_name('capybara', 'elephant')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d57a1b3a-e1a8-47bc-98b9-ed945346ebf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "animal_function = {\n",
    "    \"name\": \"get_animal_name\",\n",
    "    \"description\": \"Get the name of the cross between the two given animals. Call this whenever you are given the names of the two original animals, for example when a user enters 'capybara' and 'elephant'\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"animal1\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"the first original animal species\",\n",
    "            },\n",
    "            \"animal2\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"the second original animal species\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"animal1\", \"animal2\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "376b79be-8481-4907-a109-99c64c7aa126",
   "metadata": {},
   "outputs": [],
   "source": [
    "tools = [{\"type\": \"function\", \"function\": animal_function}]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9909a074-ba87-43a6-bb4b-bc432be951ed",
   "metadata": {},
   "source": [
    "## Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c954476-0dab-4303-8129-0a48c64d408c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def artist(animal1, animal2):\n",
    "    image_response = openai.images.generate(\n",
    "            model=\"dall-e-3\",\n",
    "            prompt=f\"An image representing a hybrid between {animal1} and {animal2}, with some features of {animal1} and some features of {animal2}, blended smoothly into a single hybrid animal, in photorealistic style. Make sure it only has one head and there is no text in the image.\",\n",
    "            size=\"1024x1024\",\n",
    "            n=1,\n",
    "            response_format=\"b64_json\",\n",
    "        )\n",
    "    image_base64 = image_response.data[0].b64_json\n",
    "    image_data = base64.b64decode(image_base64)\n",
    "    return Image.open(BytesIO(image_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f219f63-e44b-41e5-aa63-62a62356e067",
   "metadata": {},
   "outputs": [],
   "source": [
    "image = artist(\"capybara\", \"elephant\")\n",
    "display(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36d83bb3-55e3-4985-95ff-6e32ddb6cd9e",
   "metadata": {},
   "source": [
    "## Audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f87393d0-4da1-4647-8f70-13df63a283b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def talker(message):\n",
    "    response = openai.audio.speech.create(\n",
    "        model=\"tts-1\",\n",
    "        voice=\"onyx\",\n",
    "        input=message)\n",
    "\n",
    "    audio_stream = BytesIO(response.content)\n",
    "    output_filename = \"output_audio.mp3\"\n",
    "    with open(output_filename, \"wb\") as f:\n",
    "        f.write(audio_stream.read())\n",
    "\n",
    "    # Play the generated audio\n",
    "    display(Audio(output_filename, autoplay=True))\n",
    "\n",
    "talker(\"Well, hi there\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cebfd9e2-d633-400f-8c4f-f152dfda3eea",
   "metadata": {},
   "source": [
    "## Chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8883a76a-4ff5-47b3-95f8-83f9de7158af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(history):\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
    "    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
    "    image = None\n",
    "    \n",
    "    if response.choices[0].finish_reason==\"tool_calls\":\n",
    "        message = response.choices[0].message\n",
    "        response, animal1, animal2 = handle_tool_call(message)\n",
    "        messages.append(message)\n",
    "        messages.append(response)\n",
    "        image = artist(animal1, animal2)\n",
    "        response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
    "        \n",
    "    reply = response.choices[0].message.content\n",
    "    history += [{\"role\":\"assistant\", \"content\":reply}]\n",
    "\n",
    "    # Comment out or delete the next line if you'd rather skip Audio for now..\n",
    "    talker(reply)\n",
    "    \n",
    "    return history, image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d185b213-d059-4a1d-9900-3bd6e9474a1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_tool_call(message):\n",
    "    tool_call = message.tool_calls[0]\n",
    "    arguments = json.loads(tool_call.function.arguments)\n",
    "    animal1 = arguments.get('animal1')\n",
    "    animal2 = arguments.get('animal2')\n",
    "    animal_name = get_animal_name(animal1, animal2)\n",
    "    response = {\n",
    "        \"role\": \"tool\",\n",
    "        \"content\": json.dumps({\"animal1\": animal1, \"animal2\": animal2, \"animal_name\": animal_name}),\n",
    "        \"tool_call_id\": tool_call.id\n",
    "    }\n",
    "    return response, animal1, animal2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "045085a2-2267-4069-a57d-3a9e6695b272",
   "metadata": {},
   "source": [
    "## Gradio UI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c87f5f97-109e-40ab-a95b-fb63b21abc11",
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks() as ui:\n",
    "    with gr.Row():\n",
    "        chatbot = gr.Chatbot(height=500, type=\"messages\")\n",
    "        image_output = gr.Image(height=500)\n",
    "    with gr.Row():\n",
    "        entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
    "    with gr.Row():\n",
    "        clear = gr.Button(\"Clear\")\n",
    "\n",
    "    def do_entry(message, history):\n",
    "        history += [{\"role\":\"user\", \"content\":message}]\n",
    "        return \"\", history\n",
    "\n",
    "    entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry, chatbot]).then(\n",
    "        chat, inputs=chatbot, outputs=[chatbot, image_output]\n",
    "    )\n",
    "    clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)\n",
    "\n",
    "ui.launch(inbrowser=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a766324-5b75-4624-92d0-60ced31dcd26",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
